# encoding: utf-8

from yade import pack, export, plot
import math, os, sys, shutil, subprocess, tempfile
from os import path
from xml.etree import ElementTree as ET
import unittest

# Print the name of this check so its output can be identified.
print("checkVTKRecorder")

#### This is useful for printing the linenumber in the script
# import inspect
# print(inspect.currentframe().f_lineno)

from yade import pack
import yade.libVersions


class TestVTKRecorder(unittest.TestCase):
	"""Regression test of VTKRecorder output against stored reference files.

	A sphere packing placed above a fixed box, under gravity, is run for 20
	iterations with a VTKRecorder writing ASCII multiblock files (first output
	at iteration 10) into a temporary directory. The generated XML files are
	then compared element by element with the reference files stored in
	``<checksPath>/data/vtk_reference_<verName>``.
	"""

	@classmethod
	def setUpClass(cls):
		"""Initialise the class-wide settings shared by all tests.

		Sets the numeric tolerance, the list of output files to compare and the
		data arrays whose values are ignored. It also creates a temporary
		directory; the output subdirectory inside it is only created when a VTK
		version is reported by yade.
		"""
		# absolute tolerance used when comparing floating point values
		cls.delta = 1e-7
		cls.supported_vtk_versions = [6, 8, 9]
		# files written by VTKRecorder at iteration 10 (multiblock mode) that are compared
		cls.fileList = [
		        "10.vtm",
		        path.join("10", "10_0.vtu"),
		        path.join("10", "10_2.vtp"),
		]
		# elements with one of these Name attributes do not have their text (values) compared
		cls.toSkip = ["dirI", "dirII", "dirIII"]
		# an analog of max skippedLines: how many test-file elements compare_xml may skip
		cls.max_skip_count = 10

		cls.vtkVer = yade.libVersions.getVersion("vtk")
		cls.verName = None
		# use a global `checksPath` when one is defined (presumably by the checks runner),
		# otherwise the directory of this file
		cls.checksPath = globals().get("checksPath", path.dirname(__file__))
		cls.vtk_save_dir = path.join(tempfile.mkdtemp(), "vtk_testing/")
		if cls.vtkVer is not None:
			cls.setup_vtk_vars()

	@classmethod
	def setup_vtk_vars(cls):
		"""Derive the version-specific settings from ``cls.vtkVer``.

		``verName`` names the reference directory ``vtk_reference_<verName>``.
		It is the major version as a string, refined to ``"8.1"`` (minor < 2)
		or ``"8.2"`` (minor >= 2) for VTK 8. The output directory is also
		created here.
		"""
		major = cls.vtkVer[0]
		# must be a string: for VTK 8 a minor-version suffix is appended below
		cls.verName = str(major)
		if major == 6:
			# this block is only compared for VTK 6 output
			cls.fileList.append(path.join("10", "10_1.vtu"))
		elif major == 8:
			cls.verName += ".1" if cls.vtkVer[1] < 2 else ".2"
		os.makedirs(cls.vtk_save_dir, exist_ok=True)

	@classmethod
	def ppc64elLongDouble(cls):
		"""Return True on ppc64el when yade's Real has 31 decimal digits (long double build)."""
		return (yade.libVersions.getArchitecture() == "ppc64el") and (yade.math.RealHPConfig.getDigits10(1) == 31)

	def compare_text(
	        self,
	        refer_text,
	        test_text,
	        num_type="Int",
	):
		"""Compare the whitespace-separated values of two element texts.

		refer_text: text of the reference element, may be None,
		test_text: text of the generated element, may be None,
		num_type: VTK data type of the values. Types starting with "Float" are
		    compared with absolute tolerance ``self.delta``; all other types
		    are compared token by token as strings.
		"""
		# both must be None or both not None; nothing to compare when both are None
		self.assertEqual(refer_text is None, test_text is None)
		if refer_text is None:
			return

		refer_numbers = refer_text.split()
		test_numbers = test_text.split()
		self.assertEqual(
		        len(refer_numbers),
		        len(test_numbers),
		        msg=f"test failed with test_text {test_text}",
		)
		if num_type.startswith("Float"):
			for k1, k2 in zip(refer_numbers, test_numbers):
				self.assertAlmostEqual(float(k1), float(k2), delta=self.delta)
		else:
			for k1, k2 in zip(refer_numbers, test_numbers):
				self.assertEqual(k1, k2)

	def compare_xml(self, reference_xml, test_xml):
		"""Test whether two XML trees are almost the same. Element order matters.

		reference_xml: root element of an XML file shipped with the yade release,
		test_xml: root element of the corresponding XML file generated by yade.

		Both trees are walked in document order (``Element.iter()``). For each
		pair, the tag, attributes and text are compared. ``Range*`` attributes
		of ``Float*`` arrays are compared approximately. A test element missing
		any attribute key of the current reference element is skipped, at most
		``self.max_skip_count`` times in total.
		"""
		skip_count = 0
		refer_iter = reference_xml.iter()
		test_iter = test_xml.iter()
		refer_child = next(refer_iter, None)
		test_child = next(test_iter, None)
		while refer_child is not None and test_child is not None:
			# compare tag
			self.assertEqual(refer_child.tag, test_child.tag)

			# if keys do not match, skip a child in the test file and retry the same reference child
			# fixes: https://gitlab.com/yade-dev/trunk/-/issues/342
			if any(k not in test_child.attrib for k in refer_child.attrib):
				test_child = next(test_iter, None)
				skip_count += 1
				continue

			skip_compare_text = False

			# most fields announce their `type`, but `<Value index="0">` sections do not
			data_type = "Float64" if refer_child.tag == "Value" else "None"

			# compare attributes (the test file may have extra keys not in the reference)
			for k, v in refer_child.items():
				v_t = test_child.get(k)
				if k == "type":
					# following attribute values / the text may be numbers of this type;
					# relies on `type` preceding the Range* attributes in the file
					data_type = v
				elif k == "Name" and v in self.toSkip:
					skip_compare_text = True
				if k.startswith("Range") and data_type.startswith("Float"):
					self.assertAlmostEqual(float(v), float(v_t))
				else:
					self.assertEqual(v_t, v)

			# compare text
			if not skip_compare_text:
				self.compare_text(refer_child.text, test_child.text, num_type=data_type)

			refer_child = next(refer_iter, None)
			test_child = next(test_iter, None)

		# finally both iterators must be exhausted together
		self.assertIsNone(refer_child)
		self.assertIsNone(test_child)
		# we don't skip too many lines
		self.assertLessEqual(skip_count, self.max_skip_count)

	def generate_test_files(self):
		"""Build the scene in the global ``O`` and run 20 iterations.

		A fixed box and two shifted copies of the packing in
		``<checksPath>/data/100spheres`` are added. A VTKRecorder writing into
		``self.vtk_save_dir`` is part of the engines. The time step is fixed,
		because the output must be reproducible.
		"""
		checksPath = self.checksPath
		O.periodic = False
		length = 1.0
		width = 1.0
		thickness = 0.1

		O.materials.append(FrictMat(
		        density=1,
		        young=1e5,
		        poisson=0.3,
		        frictionAngle=radians(30),
		        label="boxMat",
		))
		lowBox = box(
		        center=(length / 2.0, thickness * 0.6, width / 2.0),
		        extents=(length * 2.0, thickness / 2.0, width * 2.0),
		        fixed=True,
		        wire=False,
		)
		O.bodies.append(lowBox)

		O.materials.append(FrictMat(
		        density=1000,
		        young=1e4,
		        poisson=0.3,
		        frictionAngle=radians(30),
		        label="sphereMat",
		))
		sp = pack.SpherePack()
		# Alternative generator for the packing, kept for reference (with height = 1.0, radius = 0.01):
		# sp.makeCloud((0.*length,height+1.2*radius,0.25*width),(0.5*length,2*height-1.2*radius,0.75*width),-1,.2,2000,periodic=True)
		sp.load(path.join(checksPath, "data", "100spheres"))
		# 100 was not enough to have reasonable number of collisions, so I put 200 spheres.
		O.bodies.append([sphere(s[0] + Vector3(0.0, 0.2, 0.0), s[1]) for s in sp])
		O.bodies.append([sphere(s[0] + Vector3(0.1, 0.3, 0.0), s[1]) for s in sp])

		O.dt = 5e-4
		O.usesTimeStepper = False
		newton = NewtonIntegrator(damping=0.6, gravity=(0, -10, 0))

		O.engines = [
		        ForceResetter(),
		        # (1) This is where we allow big bodies, else it would crash due to the very large bottom box:
		        InsertionSortCollider([Bo1_Box_Aabb(), Bo1_Sphere_Aabb()], allowBiggerThanPeriod=True),
		        InteractionLoop(
		                [Ig2_Sphere_Sphere_ScGeom(), Ig2_Box_Sphere_ScGeom()],
		                [Ip2_FrictMat_FrictMat_FrictPhys()],
		                [Law2_ScGeom_FrictPhys_CundallStrack()],
		        ),
		        VTKRecorder(
		                fileName=self.vtk_save_dir,
		                recorders=["all"],
		                firstIterRun=10,
		                iterPeriod=2000,
		                label="VtkRecorder",
		                ascii=True,
		                multiblock=True,
		        ),
		        newton,
		]

		for b in O.bodies:
			shade = b.id % 8 / 8.0
			b.shape.color = Vector3(shade, shade, 0.5)

		O.run(20, True)

	def test_args(self):
		"""Require a single thread and a single core, so the output is reproducible."""
		arg_fail_msg = (
		        "This test will only work on single core, because it must be fully reproducible, "
		        f"but -j {opts.threads} or --cores {opts.cores} is used."
		)
		self.assertTrue(opts.threads is None or opts.threads == 1, msg=arg_fail_msg)
		self.assertTrue(opts.cores is None or opts.cores == 1, msg=arg_fail_msg)

	def test_arch(self):
		"""Run the VTK comparison, or skip it on ppc64el long double builds and without VTK."""
		if self.ppc64elLongDouble():
			reason = "skip VTKRecorder check on ppc64el architecture compiled with long double."
			print(reason)
			# report the test as skipped instead of silently passing
			self.skipTest(reason)
		if "VTK" not in features:
			reason = "skip VTKRecorder check, VTK is not available"
			print(reason)
			self.skipTest(reason)
		self.check_vtk()

	def test_vtk_ver(self):
		"""Fail if the detected VTK major version has no reference results."""
		fail_msg = (
		        f"checkVTKRecorder does not have reference results for VTK version {self.vtkVer}, "
		        f"check files in {self.vtk_save_dir}, "
		        f"if they are correct add them to: scripts/checks-and-tests/checks/data/vtk_reference_{self.verName}/"
		)
		if self.vtkVer is not None:
			self.assertTrue(self.vtkVer[0] in self.supported_vtk_versions, msg=fail_msg)

	def check_vtk(self):
		"""Generate the VTK files and compare each file of ``fileList`` with its reference.

		After generating the files, fail with an explanatory message if no
		reference directory exists for the detected VTK version.
		"""
		test_dir = self.vtk_save_dir
		reference_dir = path.join(self.checksPath, "data", f"vtk_reference_{self.verName}")

		self.generate_test_files()
		# test_vtk_ver gives similar advice, but it runs after this test (alphabetical order)
		# and a fail-fast runner would stop before reaching it.
		if not path.isdir(reference_dir):
			self.fail(
			        f"checkVTKRecorder does not have reference results for VTK version {self.vtkVer} "
			        f"(missing directory {reference_dir}), "
			        f"check files in {test_dir}, "
			        f"if they are correct add them to: scripts/checks-and-tests/checks/data/vtk_reference_{self.verName}/"
			)
		for target_file in self.fileList:
			print("Check file", target_file)
			test_XML = ET.parse(path.join(test_dir, target_file)).getroot()
			reference_XML = ET.parse(path.join(reference_dir, target_file)).getroot()
			self.compare_xml(reference_XML, test_XML)


class VTKRecorderTestResult(unittest.TextTestResult):
	"""TextTestResult that records an error or failure and then re-raises it.

	Re-raising makes the exception propagate out of the test runner instead of
	only being collected in ``errors`` / ``failures``.
	"""

	def addError(self, test, err):
		# record as usual, then re-raise the exception instance from the
		# (type, value, traceback) triple
		super().addError(test, err)
		_, exc_value, _ = err
		raise exc_value

	def addFailure(self, test, err):
		# same as addError, for assertion failures
		super().addFailure(test, err)
		_, exc_value, _ = err
		raise exc_value


class VTKRecorderTestRunner(unittest.TextTestRunner):
	"""TextTestRunner that collects results in a VTKRecorderTestResult."""

	def _makeResult(self):
		stream, descriptions, verbosity = self.stream, self.descriptions, self.verbosity
		return VTKRecorderTestResult(stream, descriptions, verbosity)


if __name__ == "__main__":
	# Load every test of TestVTKRecorder into one suite and run it with
	# VTKRecorderTestRunner, whose result re-raises errors and failures.
	suite = unittest.TestSuite()
	suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestVTKRecorder))
	VTKRecorderTestRunner().run(suite)
